package com.example.djlocrspringboot.djlocr.djl;

import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.util.List;
import java.util.Objects;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.stereotype.Component;

import com.example.djlocrspringboot.djlocr.djl.utils.common.RotatedBox;
import com.example.djlocrspringboot.djlocr.djl.utils.detection.OcrV3Detection;
import com.example.djlocrspringboot.djlocr.djl.utils.recognition.OcrV3Recognition;

import ai.djl.MalformedModelException;
import ai.djl.inference.Predictor;
import ai.djl.modality.cv.Image;
import ai.djl.ndarray.NDList;
import ai.djl.opencv.OpenCVImageFactory;
import ai.djl.repository.zoo.ModelNotFoundException;
import ai.djl.repository.zoo.ModelZoo;
import ai.djl.repository.zoo.ZooModel;
import ai.djl.translate.TranslateException;

/**
 * Text recognition (OCR) using the OCR V3 detection and recognition models. Supports text that is
 * rotated at an angle.
 *
 * <p>Exposes a single static entry point, {@link #djlOcr(byte[])}.
 *
 * @author Calvin
 * @date 2022-10-07
 * @email 179209347@qq.com
 */
@Component
public final class OcrV3RecognitionExample {

    private static final Logger logger = LoggerFactory.getLogger(OcrV3RecognitionExample.class);

    private OcrV3RecognitionExample() {
    }

    /**
     * Runs text detection followed by text recognition on an encoded image and returns the
     * recognized text.
     *
     * <p>The text of every box returned by {@code OcrV3Recognition.predict} is concatenated in list
     * order with no separator.
     *
     * <p>NOTE(review): both models are loaded from the model zoo and closed again on every call,
     * which is expensive. A possible improvement is to load the models once and share them, while
     * still creating a fresh {@link Predictor} per call (DJL predictors are not meant to be shared
     * across threads). This is left unchanged pending a decision on model lifecycle.
     *
     * @param imageByte encoded image bytes, decoded with {@link OpenCVImageFactory}; must not be
     *     {@code null}
     * @return the concatenated recognized text, or an empty string when nothing is detected
     * @throws NullPointerException if {@code imageByte} is {@code null}
     * @throws RuntimeException wrapping any I/O, model-loading, or inference failure
     */
    public static String djlOcr(byte[] imageByte) {
        Objects.requireNonNull(imageByte, "imageByte");
        try {
            // Decode first so that a bad image fails before the (costly) model loading.
            Image image;
            try (InputStream inputStream = new ByteArrayInputStream(imageByte)) {
                image = OpenCVImageFactory.getInstance().fromInputStream(inputStream);
            }

            OcrV3Detection detection = new OcrV3Detection();
            OcrV3Recognition recognition = new OcrV3Recognition();
            // Resources are closed in reverse order: predictors before the models that own them.
            try (ZooModel<Image, NDList> detectionModel =
                         ModelZoo.loadModel(detection.detectCriteria());
                 Predictor<Image, NDList> detector = detectionModel.newPredictor();
                 ZooModel<Image, String> recognitionModel =
                         ModelZoo.loadModel(recognition.recognizeCriteria());
                 Predictor<Image, String> recognizer = recognitionModel.newPredictor()) {
                long startNanos = System.nanoTime();
                List<RotatedBox> detections = recognition.predict(image, detector, recognizer);
                long elapsedMillis = (System.nanoTime() - startNanos) / 1_000_000L;
                logger.info("ocr one image time: {} ms", elapsedMillis);

                String text = joinText(detections);
                // Recognized text may be sensitive; keep it out of default (info) logs.
                logger.debug("ocr result: {}", text);
                return text;
            }
        } catch (IOException | ModelNotFoundException | MalformedModelException
                | TranslateException e) {
            throw new RuntimeException("DJL OCR failed", e);
        }
    }

    /** Concatenates the recognized text of each box, in list order, without a separator. */
    private static String joinText(List<RotatedBox> boxes) {
        StringBuilder text = new StringBuilder();
        for (RotatedBox box : boxes) {
            text.append(box.getText());
        }
        return text.toString();
    }
}
